{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6dbdc74f",
   "metadata": {},
   "source": [
    "## 5.2 多层感知机代码实现"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03a88fe0",
   "metadata": {},
   "source": [
    "在 PyTorch 中实现多层感知机（MLP）可以分为以下几步：\n",
    "\n",
    "1. 导入所需的库和模块，包括 PyTorch 的 nn 模块，torch.optim 模块和数据加载及预处理的常用库，例如 torchvision 和 torchtext。\n",
    "2. 定义多层感知机模型。这可以通过继承 PyTorch 的 nn.Module 类并定义前向传播函数来完成。\n",
    "3. 加载数据集。这可以使用 PyTorch 提供的数据加载器或自定义加载器完成。\n",
    "4. 定义损失函数和优化器。这可以使用 PyTorch 的内置函数和优化器来完成。\n",
    "5. 开始训练模型。在训练循环中，您需要通过获取输入和标签，计算模型输出，计算损失并更新模型参数来训练模型。\n",
    "6. 在训练之后，可以使用模型进行推理或将其保存以供将来使用。\n",
    "\n",
    "下面这个例子中，梗直哥给你演示了一个包含两个隐藏层的 MLP实现，使用 MNIST 数据集进行训练，一起来感受一下整个过程："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "807a38b7",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "\n",
    "# 定义 MLP 网络\n",
    "class MLP(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, num_classes):\n",
    "        super(MLP, self).__init__()\n",
    "        self.fc1 = nn.Linear(input_size, hidden_size)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.fc2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.fc3 = nn.Linear(hidden_size, num_classes)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        out = self.fc1(x)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc2(out)\n",
    "        out = self.relu(out)\n",
    "        out = self.fc3(out)\n",
    "        return out\n",
    "\n",
    "# 定义超参数\n",
    "input_size = 28 * 28  # 输入大小\n",
    "hidden_size = 512  # 隐藏层大小\n",
    "num_classes = 10  # 输出大小（类别数）\n",
    "batch_size = 100  # 批大小\n",
    "learning_rate = 0.001  # 学习率\n",
    "num_epochs = 10  # 训练轮数\n",
    "\n",
    "# 加载 MNIST 数据集\n",
    "train_dataset = datasets.MNIST(root='../data/mnist', train=True, transform=transforms.ToTensor(), download=True)\n",
    "test_dataset = datasets.MNIST(root='../data/mnist', train=False, transform=transforms.ToTensor(), download=True)\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)\n",
    "\n",
    "# 实例化 MLP 网络\n",
    "model = MLP(input_size, hidden_size, num_classes)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0a980fee",
   "metadata": {},
   "source": [
    "现在我们已经定义了 MLP 网络并加载了 MNIST 数据集，接下来使用 PyTorch 的自动求导功能和优化器进行训练。首先，定义损失函数和优化器；然后迭代训练数据并使用优化器更新网络参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "bf50be51",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [1/10], Step [100/600], Loss: 0.0419\n",
      "Epoch [1/10], Step [200/600], Loss: 0.0931\n",
      "Epoch [1/10], Step [300/600], Loss: 0.0609\n",
      "Epoch [1/10], Step [400/600], Loss: 0.0482\n",
      "Epoch [1/10], Step [500/600], Loss: 0.1138\n",
      "Epoch [1/10], Step [600/600], Loss: 0.0533\n",
      "Epoch [2/10], Step [100/600], Loss: 0.0340\n",
      "Epoch [2/10], Step [200/600], Loss: 0.0619\n",
      "Epoch [2/10], Step [300/600], Loss: 0.2061\n",
      "Epoch [2/10], Step [400/600], Loss: 0.0695\n",
      "Epoch [2/10], Step [500/600], Loss: 0.0269\n",
      "Epoch [2/10], Step [600/600], Loss: 0.0330\n",
      "Epoch [3/10], Step [100/600], Loss: 0.0135\n",
      "Epoch [3/10], Step [200/600], Loss: 0.0710\n",
      "Epoch [3/10], Step [300/600], Loss: 0.0089\n",
      "Epoch [3/10], Step [400/600], Loss: 0.0139\n",
      "Epoch [3/10], Step [500/600], Loss: 0.0786\n",
      "Epoch [3/10], Step [600/600], Loss: 0.0331\n",
      "Epoch [4/10], Step [100/600], Loss: 0.0072\n",
      "Epoch [4/10], Step [200/600], Loss: 0.0183\n",
      "Epoch [4/10], Step [300/600], Loss: 0.0291\n",
      "Epoch [4/10], Step [400/600], Loss: 0.0399\n",
      "Epoch [4/10], Step [500/600], Loss: 0.0065\n",
      "Epoch [4/10], Step [600/600], Loss: 0.0306\n",
      "Epoch [5/10], Step [100/600], Loss: 0.0097\n",
      "Epoch [5/10], Step [200/600], Loss: 0.0073\n",
      "Epoch [5/10], Step [300/600], Loss: 0.0327\n",
      "Epoch [5/10], Step [400/600], Loss: 0.0027\n",
      "Epoch [5/10], Step [500/600], Loss: 0.0254\n",
      "Epoch [5/10], Step [600/600], Loss: 0.0136\n",
      "Epoch [6/10], Step [100/600], Loss: 0.0195\n",
      "Epoch [6/10], Step [200/600], Loss: 0.0124\n",
      "Epoch [6/10], Step [300/600], Loss: 0.0065\n",
      "Epoch [6/10], Step [400/600], Loss: 0.0975\n",
      "Epoch [6/10], Step [500/600], Loss: 0.0333\n",
      "Epoch [6/10], Step [600/600], Loss: 0.0346\n",
      "Epoch [7/10], Step [100/600], Loss: 0.0055\n",
      "Epoch [7/10], Step [200/600], Loss: 0.0003\n",
      "Epoch [7/10], Step [300/600], Loss: 0.0014\n",
      "Epoch [7/10], Step [400/600], Loss: 0.0052\n",
      "Epoch [7/10], Step [500/600], Loss: 0.0592\n",
      "Epoch [7/10], Step [600/600], Loss: 0.0139\n",
      "Epoch [8/10], Step [100/600], Loss: 0.0196\n",
      "Epoch [8/10], Step [200/600], Loss: 0.0122\n",
      "Epoch [8/10], Step [300/600], Loss: 0.0211\n",
      "Epoch [8/10], Step [400/600], Loss: 0.0009\n",
      "Epoch [8/10], Step [500/600], Loss: 0.0464\n",
      "Epoch [8/10], Step [600/600], Loss: 0.0207\n",
      "Epoch [9/10], Step [100/600], Loss: 0.0078\n",
      "Epoch [9/10], Step [200/600], Loss: 0.0047\n",
      "Epoch [9/10], Step [300/600], Loss: 0.0029\n",
      "Epoch [9/10], Step [400/600], Loss: 0.0047\n",
      "Epoch [9/10], Step [500/600], Loss: 0.0724\n",
      "Epoch [9/10], Step [600/600], Loss: 0.0219\n",
      "Epoch [10/10], Step [100/600], Loss: 0.0008\n",
      "Epoch [10/10], Step [200/600], Loss: 0.0054\n",
      "Epoch [10/10], Step [300/600], Loss: 0.0015\n",
      "Epoch [10/10], Step [400/600], Loss: 0.0029\n",
      "Epoch [10/10], Step [500/600], Loss: 0.0043\n",
      "Epoch [10/10], Step [600/600], Loss: 0.0025\n"
     ]
    }
   ],
   "source": [
    "# 定义损失函数和优化器\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)\n",
    "\n",
    "# 训练网络\n",
    "for epoch in range(num_epochs):\n",
    "    for i, (images, labels) in enumerate(train_loader):\n",
    "        images = images.reshape(-1, 28 * 28)\n",
    "        outputs = model(images)\n",
    "        loss = criterion(outputs, labels)\n",
    "        \n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        if (i + 1) % 100 == 0:\n",
    "            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item():.4f}')\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "01fe79f6",
   "metadata": {},
   "source": [
    "最后，我们可以在测试数据上评估模型的准确率："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "47c6ecaf",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy of the network on the 10000 test images: 97.88 %\n"
     ]
    }
   ],
   "source": [
    "# 测试网络\n",
    "with torch.no_grad():\n",
    "    correct = 0\n",
    "    total = 0\n",
    "    for images, labels in test_loader:\n",
    "        images = images.reshape(-1, 28 * 28)\n",
    "        outputs = model(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "    print(f'Accuracy of the network on the 10000 test images: {100 * correct / total} %')\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "342e2d3a",
   "metadata": {},
   "source": [
    "可以看到训练效果还不错，准确率97.88%。\n",
    "\n",
    "**梗直哥建议：我们这节课言简意赅的讲解了一个例子，主要目的是突出代码实现。如果你在理解方面感觉自己有问题，可以稍微分析一下原因。如果代码看不懂是python的问题，可以考虑补充这方面知识。如果对神经网络原理还希望了解更多，可以选修哥[《机器学习必修课：python实战》](https://appmixy0usl5902.h5.xiaoeknow.com/)中神经网络相关章节内容。如果是运行或调参经验缺乏，欢迎入群讨论（微信：gengzhige99）**"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a1c28af1",
   "metadata": {},
   "source": [
    "[Next 5-3前向传播和反向传播](./5-3%20前向传播和反向传播.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9471748d",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
